{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "GPT-J-6B Inference Demo.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true,
   "machine_shape": "hm"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "TPU",
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "5b8a31b3b4034116af0a81d897c3123b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_view_name": "HBoxView",
      "_dom_classes": [],
      "_model_name": "HBoxModel",
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "box_style": "",
      "layout": "IPY_MODEL_835ee2070c05443c9fa3ae6fece3fb44",
      "_model_module": "@jupyter-widgets/controls",
      "children": [
       "IPY_MODEL_9e6646f7066341ef9619f5f7e6f7c8b8",
       "IPY_MODEL_4f803e3d56b74e18997b216bf9d65337"
      ]
     }
    },
    "835ee2070c05443c9fa3ae6fece3fb44": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "9e6646f7066341ef9619f5f7e6f7c8b8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_view_name": "ProgressView",
      "style": "IPY_MODEL_f9fa7c1bc9f04df09e487b16761d6b8f",
      "_dom_classes": [],
      "description": "Downloading: 100%",
      "_model_name": "FloatProgressModel",
      "bar_style": "success",
      "max": 1042301,
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "value": 1042301,
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "orientation": "horizontal",
      "min": 0,
      "description_tooltip": null,
      "_model_module": "@jupyter-widgets/controls",
      "layout": "IPY_MODEL_39aed05f694a495abd03aa00ff71efb3"
     }
    },
    "4f803e3d56b74e18997b216bf9d65337": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_view_name": "HTMLView",
      "style": "IPY_MODEL_61b0857802184b51abf77039fad681d7",
      "_dom_classes": [],
      "description": "",
      "_model_name": "HTMLModel",
      "placeholder": "​",
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "value": " 1.04M/1.04M [00:07&lt;00:00, 148kB/s]",
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "description_tooltip": null,
      "_model_module": "@jupyter-widgets/controls",
      "layout": "IPY_MODEL_0fb689df0a8b42c2b87e7f11b818b88a"
     }
    },
    "f9fa7c1bc9f04df09e487b16761d6b8f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_view_name": "StyleView",
      "_model_name": "ProgressStyleModel",
      "description_width": "initial",
      "_view_module": "@jupyter-widgets/base",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.2.0",
      "bar_color": null,
      "_model_module": "@jupyter-widgets/controls"
     }
    },
    "39aed05f694a495abd03aa00ff71efb3": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "61b0857802184b51abf77039fad681d7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_view_name": "StyleView",
      "_model_name": "DescriptionStyleModel",
      "description_width": "",
      "_view_module": "@jupyter-widgets/base",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.2.0",
      "_model_module": "@jupyter-widgets/controls"
     }
    },
    "0fb689df0a8b42c2b87e7f11b818b88a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "8b76924bdac749d5b4908678fd0a7497": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_view_name": "HBoxView",
      "_dom_classes": [],
      "_model_name": "HBoxModel",
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "box_style": "",
      "layout": "IPY_MODEL_b03bd083afc84f52834d4a124e57c442",
      "_model_module": "@jupyter-widgets/controls",
      "children": [
       "IPY_MODEL_3ba9fe91731f4ff380cda2183ad72c4d",
       "IPY_MODEL_4e01b9cceffb4d178e97f1d1b36d62b0"
      ]
     }
    },
    "b03bd083afc84f52834d4a124e57c442": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "3ba9fe91731f4ff380cda2183ad72c4d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_view_name": "ProgressView",
      "style": "IPY_MODEL_b3ac7912f51746afae20616269dc69f0",
      "_dom_classes": [],
      "description": "Downloading: 100%",
      "_model_name": "FloatProgressModel",
      "bar_style": "success",
      "max": 456318,
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "value": 456318,
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "orientation": "horizontal",
      "min": 0,
      "description_tooltip": null,
      "_model_module": "@jupyter-widgets/controls",
      "layout": "IPY_MODEL_fb1263f5a3da42e6acab0e9a66551b32"
     }
    },
    "4e01b9cceffb4d178e97f1d1b36d62b0": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_view_name": "HTMLView",
      "style": "IPY_MODEL_3426927332884df4b1ec5544ed479c45",
      "_dom_classes": [],
      "description": "",
      "_model_name": "HTMLModel",
      "placeholder": "​",
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "value": " 456k/456k [00:00&lt;00:00, 928kB/s]",
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "description_tooltip": null,
      "_model_module": "@jupyter-widgets/controls",
      "layout": "IPY_MODEL_debf4a0a609a4cffa5cd686a0a23f6e2"
     }
    },
    "b3ac7912f51746afae20616269dc69f0": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_view_name": "StyleView",
      "_model_name": "ProgressStyleModel",
      "description_width": "initial",
      "_view_module": "@jupyter-widgets/base",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.2.0",
      "bar_color": null,
      "_model_module": "@jupyter-widgets/controls"
     }
    },
    "fb1263f5a3da42e6acab0e9a66551b32": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "3426927332884df4b1ec5544ed479c45": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_view_name": "StyleView",
      "_model_name": "DescriptionStyleModel",
      "description_width": "",
      "_view_module": "@jupyter-widgets/base",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.2.0",
      "_model_module": "@jupyter-widgets/controls"
     }
    },
    "debf4a0a609a4cffa5cd686a0a23f6e2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "7197de8a39ab4e4186eabd061c730cf5": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_view_name": "HBoxView",
      "_dom_classes": [],
      "_model_name": "HBoxModel",
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "box_style": "",
      "layout": "IPY_MODEL_c31661c9a1ff4bd681cc10d70fd287c0",
      "_model_module": "@jupyter-widgets/controls",
      "children": [
       "IPY_MODEL_bb11a4fbf16f488082aa2bac7443fd62",
       "IPY_MODEL_5360630d6b0541dbaafa7ff91610bd84"
      ]
     }
    },
    "c31661c9a1ff4bd681cc10d70fd287c0": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "bb11a4fbf16f488082aa2bac7443fd62": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_view_name": "ProgressView",
      "style": "IPY_MODEL_f3883c133c62438db114c22cfc94e351",
      "_dom_classes": [],
      "description": "Downloading: 100%",
      "_model_name": "FloatProgressModel",
      "bar_style": "success",
      "max": 1355256,
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "value": 1355256,
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "orientation": "horizontal",
      "min": 0,
      "description_tooltip": null,
      "_model_module": "@jupyter-widgets/controls",
      "layout": "IPY_MODEL_74672c43909444c592955afd0d727a28"
     }
    },
    "5360630d6b0541dbaafa7ff91610bd84": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_view_name": "HTMLView",
      "style": "IPY_MODEL_67f7c3bc59384dc7812d7130c7cde450",
      "_dom_classes": [],
      "description": "",
      "_model_name": "HTMLModel",
      "placeholder": "​",
      "_view_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "value": " 1.36M/1.36M [00:03&lt;00:00, 404kB/s]",
      "_view_count": null,
      "_view_module_version": "1.5.0",
      "description_tooltip": null,
      "_model_module": "@jupyter-widgets/controls",
      "layout": "IPY_MODEL_bb7410ff9cb4484da73302993427b90a"
     }
    },
    "f3883c133c62438db114c22cfc94e351": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_view_name": "StyleView",
      "_model_name": "ProgressStyleModel",
      "description_width": "initial",
      "_view_module": "@jupyter-widgets/base",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.2.0",
      "bar_color": null,
      "_model_module": "@jupyter-widgets/controls"
     }
    },
    "74672c43909444c592955afd0d727a28": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    },
    "67f7c3bc59384dc7812d7130c7cde450": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_view_name": "StyleView",
      "_model_name": "DescriptionStyleModel",
      "description_width": "",
      "_view_module": "@jupyter-widgets/base",
      "_model_module_version": "1.5.0",
      "_view_count": null,
      "_view_module_version": "1.2.0",
      "_model_module": "@jupyter-widgets/controls"
     }
    },
    "bb7410ff9cb4484da73302993427b90a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_view_name": "LayoutView",
      "grid_template_rows": null,
      "right": null,
      "justify_content": null,
      "_view_module": "@jupyter-widgets/base",
      "overflow": null,
      "_model_module_version": "1.2.0",
      "_view_count": null,
      "flex_flow": null,
      "width": null,
      "min_width": null,
      "border": null,
      "align_items": null,
      "bottom": null,
      "_model_module": "@jupyter-widgets/base",
      "top": null,
      "grid_column": null,
      "overflow_y": null,
      "overflow_x": null,
      "grid_auto_flow": null,
      "grid_area": null,
      "grid_template_columns": null,
      "flex": null,
      "_model_name": "LayoutModel",
      "justify_items": null,
      "grid_row": null,
      "max_height": null,
      "align_content": null,
      "visibility": null,
      "align_self": null,
      "height": null,
      "min_height": null,
      "padding": null,
      "grid_auto_rows": null,
      "grid_gap": null,
      "max_width": null,
      "order": null,
      "_view_module_version": "1.2.0",
      "grid_template_areas": null,
      "object_position": null,
      "object_fit": null,
      "grid_auto_columns": null,
      "margin": null,
      "display": null,
      "left": null
     }
    }
   }
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pHIJVqHsh4An"
   },
   "source": [
    "# GPT-J-6B Inference Demo\n",
    "\n",
    "<a href=\"http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8CMw_dSQKfhT"
   },
   "source": [
    "## Install Dependencies\n",
    "\n",
    "First, we download the model weights and install the dependencies. This step takes at least 5 minutes (possibly longer, depending on server load).\n",
    "\n",
    "!!! **Make sure you are using a TPU runtime!** !!!"
   ]
  },
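  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can check the TPU requirement before starting the roughly 9 GB checkpoint download. The snippet below is a minimal sketch and is not part of mesh-transformer-jax. It assumes that Colab's TPU runtime publishes the TPU address in the `COLAB_TPU_ADDR` environment variable as `host:port`. The `colab_tpu_host` helper is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import os\n",
    "from typing import Mapping, Optional\n",
    "\n",
    "\n",
    "def colab_tpu_host(env: Mapping[str, str] = os.environ) -> Optional[str]:\n",
    "    # Assumption: Colab TPU runtimes set COLAB_TPU_ADDR to 'host:port'.\n",
    "    # Return only the host part, or None when the variable is absent or empty.\n",
    "    addr = env.get('COLAB_TPU_ADDR')\n",
    "    if not addr:\n",
    "        return None\n",
    "    return addr.split(':')[0]\n",
    "\n",
    "\n",
    "host = colab_tpu_host()\n",
    "if host is None:\n",
    "    print('COLAB_TPU_ADDR is not set; switch to a TPU runtime '\n",
    "          '(Runtime > Change runtime type) before downloading the weights.')\n",
    "else:\n",
    "    print(f'TPU runtime detected at {host}')\n",
    "```\n",
    "\n",
    "This check reads the environment variable instead of calling `jax.devices()`. That way it avoids importing the JAX version preinstalled in Colab before the pinned `jax==0.2.12` is installed in the next cell."
   ]
  },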
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "n7xAFw-LOYfe",
    "outputId": "7bd7fd83-a40c-41a2-8a14-9a11fc2f150c"
   },
   "source": [
    "!apt install zstd\n",
    "\n",
    "# the \"slim\" version contains only bf16 weights and no optimizer parameters, which reduces download size and memory use\n",
    "!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd\n",
    "\n",
    "!time tar -I zstd -xf step_383500_slim.tar.zstd\n",
    "\n",
    "!git clone https://github.com/kingoflolz/mesh-transformer-jax.git\n",
    "!pip install -r mesh-transformer-jax/requirements.txt\n",
    "\n",
    "# jax 0.2.12 is required due to a regression with xmap in 0.2.13\n",
    "!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0"
   ],
   "execution_count": 1,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "Reading package lists... Done\n",
      "Building dependency tree       \n",
      "Reading state information... Done\n",
      "The following NEW packages will be installed:\n",
      "  zstd\n",
      "0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.\n",
      "Need to get 278 kB of archives.\n",
      "After this operation, 1,141 kB of additional disk space will be used.\n",
      "Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]\n",
      "Fetched 278 kB in 1s (369 kB/s)\n",
      "Selecting previously unselected package zstd.\n",
      "(Reading database ... 160772 files and directories currently installed.)\n",
      "Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...\n",
      "Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...\n",
      "Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...\n",
      "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n",
      "--2021-06-09 01:49:35--  https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd\n",
      "Resolving the-eye.eu (the-eye.eu)... 162.213.130.242\n",
      "Connecting to the-eye.eu (the-eye.eu)|162.213.130.242|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 9414712325 (8.8G) [application/octet-stream]\n",
      "Saving to: ‘step_383500_slim.tar.zstd’\n",
      "\n",
      "step_383500_slim.ta 100%[===================>]   8.77G  85.8MB/s    in 1m 44s  \n",
      "\n",
      "2021-06-09 01:51:19 (86.4 MB/s) - ‘step_383500_slim.tar.zstd’ saved [9414712325/9414712325]\n",
      "\n",
      "\n",
      "real\t1m44.272s\n",
      "user\t0m5.735s\n",
      "sys\t0m26.674s\n",
      "\n",
      "real\t1m28.680s\n",
      "user\t0m31.887s\n",
      "sys\t0m30.531s\n",
      "Cloning into 'mesh-transformer-jax'...\n",
      "remote: Enumerating objects: 444, done.\u001B[K\n",
      "remote: Counting objects: 100% (444/444), done.\u001B[K\n",
      "remote: Compressing objects: 100% (299/299), done.\u001B[K\n",
      "remote: Total 444 (delta 294), reused 280 (delta 130), pack-reused 0\u001B[K\n",
      "Receiving objects: 100% (444/444), 97.43 KiB | 1.37 MiB/s, done.\n",
      "Resolving deltas: 100% (294/294), done.\n",
      "Collecting git+https://github.com/deepmind/dm-haiku (from -r mesh-transformer-jax/requirements.txt (line 10))\n",
      "  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-rzypnds1\n",
      "  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-rzypnds1\n",
      "Collecting git+https://github.com/EleutherAI/lm-evaluation-harness/ (from -r mesh-transformer-jax/requirements.txt (line 11))\n",
      "  Cloning https://github.com/EleutherAI/lm-evaluation-harness/ to /tmp/pip-req-build-60v7ec7y\n",
      "  Running command git clone -q https://github.com/EleutherAI/lm-evaluation-harness/ /tmp/pip-req-build-60v7ec7y\n",
      "Requirement already satisfied: numpy~=1.19.5 in /usr/local/lib/python3.7/dist-packages (from -r mesh-transformer-jax/requirements.txt (line 1)) (1.19.5)\n",
      "Collecting transformers~=4.4.2\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)\n",
      "\u001B[K     |████████████████████████████████| 2.0MB 3.1MB/s \n",
      "\u001B[?25hCollecting tqdm~=4.45.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/4a/1c/6359be64e8301b84160f6f6f7936bbfaaa5e9a4eab6cbc681db07600b949/tqdm-4.45.0-py2.py3-none-any.whl (60kB)\n",
      "\u001B[K     |████████████████████████████████| 61kB 6.9MB/s \n",
      "\u001B[?25hCollecting setuptools~=51.3.3\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/b2/81/509db0082c0d2ca2af307c6652ea422865de1f83c14b1e1f3549e415cfac/setuptools-51.3.3-py3-none-any.whl (786kB)\n",
      "\u001B[K     |████████████████████████████████| 788kB 19.5MB/s \n",
      "\u001B[?25hCollecting wandb~=0.10.22\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/6c/48/b199e2b3b341ac842108c5db4956091dd75d961cfa77aceb033e99cac20f/wandb-0.10.31-py2.py3-none-any.whl (1.8MB)\n",
      "\u001B[K     |████████████████████████████████| 1.8MB 25.5MB/s \n",
      "\u001B[?25hCollecting einops~=0.3.0\n",
      "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
      "Collecting requests~=2.25.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/29/c1/24814557f1d22c56d50280771a17307e6bf87b70727d975fd6b2ce6b014a/requests-2.25.1-py2.py3-none-any.whl (61kB)\n",
      "\u001B[K     |████████████████████████████████| 61kB 7.1MB/s \n",
      "\u001B[?25hCollecting fabric~=2.6.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/c1/9d/59df62b620985871a4ba7d8b509b84340bbd1573257e55a427ae2df2d56e/fabric-2.6.0-py2.py3-none-any.whl (53kB)\n",
      "\u001B[K     |████████████████████████████████| 61kB 7.1MB/s \n",
      "\u001B[?25hCollecting optax~=0.0.2\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/ec/7a/6259edd319ee7fa94dd23c54f15eff667f599d179e889af90fe0c204612c/optax-0.0.6-py3-none-any.whl (96kB)\n",
      "\u001B[K     |████████████████████████████████| 102kB 10.2MB/s \n",
      "\u001B[?25hCollecting ray~=1.2.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/11/14/15d0f0aec20a4674a996429160565a071688f27f49f789327ebed8188ffb/ray-1.2.0-cp37-cp37m-manylinux2014_x86_64.whl (47.5MB)\n",
      "\u001B[K     |████████████████████████████████| 47.5MB 101kB/s \n",
      "\u001B[?25hRequirement already satisfied: jax~=0.2.12 in /usr/local/lib/python3.7/dist-packages (from -r mesh-transformer-jax/requirements.txt (line 13)) (0.2.13)\n",
      "Requirement already satisfied: Flask~=1.1.2 in /usr/local/lib/python3.7/dist-packages (from -r mesh-transformer-jax/requirements.txt (line 14)) (1.1.4)\n",
      "Requirement already satisfied: cloudpickle~=1.3.0 in /usr/local/lib/python3.7/dist-packages (from -r mesh-transformer-jax/requirements.txt (line 15)) (1.3.0)\n",
      "Collecting tensorflow-cpu~=2.4.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/f0/7b/b4d3fbc3ad4d7b5255dff433c06c0105f7815ed584e41f47cb2a00b0c079/tensorflow_cpu-2.4.1-cp37-cp37m-manylinux2010_x86_64.whl (144.1MB)\n",
      "\u001B[K     |████████████████████████████████| 144.2MB 47kB/s \n",
      "\u001B[?25hCollecting google-cloud-storage~=1.36.2\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/f2/0e/da07ffa511daa559bcc209f9344d71a90ba4d7b391fb795e6282f86d2935/google_cloud_storage-1.36.2-py2.py3-none-any.whl (97kB)\n",
      "\u001B[K     |████████████████████████████████| 102kB 10.8MB/s \n",
      "\u001B[?25hRequirement already satisfied: smart_open[gcs] in /usr/local/lib/python3.7/dist-packages (from -r mesh-transformer-jax/requirements.txt (line 18)) (5.0.0)\n",
      "Collecting func_timeout\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/b3/0d/bf0567477f7281d9a3926c582bfef21bff7498fc0ffd3e9de21811896a0b/func_timeout-4.3.5.tar.gz (44kB)\n",
      "\u001B[K     |████████████████████████████████| 51kB 6.1MB/s \n",
      "\u001B[?25hCollecting ftfy\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/af/da/d215a091986e5f01b80f5145cff6f22e2dc57c6b048aab2e882a07018473/ftfy-6.0.3.tar.gz (64kB)\n",
      "\u001B[K     |████████████████████████████████| 71kB 8.1MB/s \n",
      "\u001B[?25hRequirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0->-r mesh-transformer-jax/requirements.txt (line 10)) (0.12.0)\n",
      "Collecting jmp>=0.0.2\n",
      "  Downloading https://files.pythonhosted.org/packages/ff/5c/1482f4a4a502e080af2ca54d7f80a60b5d4735f464c151666d583b78c226/jmp-0.0.2-py3-none-any.whl\n",
      "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0->-r mesh-transformer-jax/requirements.txt (line 10)) (0.8.9)\n",
      "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0->-r mesh-transformer-jax/requirements.txt (line 10)) (3.7.4.3)\n",
      "Collecting black==20.8b1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/dc/7b/5a6bbe89de849f28d7c109f5ea87b65afa5124ad615f3419e71beb29dc96/black-20.8b1.tar.gz (1.1MB)\n",
      "\u001B[K     |████████████████████████████████| 1.1MB 48.5MB/s \n",
      "\u001B[?25h  Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n",
      "  Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n",
      "    Preparing wheel metadata ... \u001B[?25l\u001B[?25hdone\n",
      "Collecting best_download>=0.0.6\n",
      "  Downloading https://files.pythonhosted.org/packages/5f/03/155d24ac8b3e1d6c21daddc357001045486c3a31beab505e6cfc77b4ee7e/best_download-0.0.6-py3-none-any.whl\n",
      "Collecting datasets>=1.2.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/08/a2/d4e1024c891506e1cee8f9d719d20831bac31cb5b7416983c4d2f65a6287/datasets-1.8.0-py3-none-any.whl (237kB)\n",
      "\u001B[K     |████████████████████████████████| 245kB 50.0MB/s \n",
      "\u001B[?25hRequirement already satisfied: click>=7.1 in /usr/local/lib/python3.7/dist-packages (from lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (7.1.2)\n",
      "Collecting scikit-learn>=0.24.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/a8/eb/a48f25c967526b66d5f1fa7a984594f0bf0a5afafa94a8c4dbc317744620/scikit_learn-0.24.2-cp37-cp37m-manylinux2010_x86_64.whl (22.3MB)\n",
      "\u001B[K     |████████████████████████████████| 22.3MB 1.2MB/s \n",
      "\u001B[?25hRequirement already satisfied: torch>=1.7 in /usr/local/lib/python3.7/dist-packages (from lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.8.1+cu101)\n",
      "Collecting sqlitedict==1.6.0\n",
      "  Downloading https://files.pythonhosted.org/packages/0f/1c/c757b93147a219cf1e25cef7e1ad9b595b7f802159493c45ce116521caff/sqlitedict-1.6.0.tar.gz\n",
      "Collecting pytablewriter==0.58.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/fd/e2/62b208cdb8771dee1849bd2b4ed129284e1efff7669985697e4c124c1000/pytablewriter-0.58.0-py3-none-any.whl (96kB)\n",
      "\u001B[K     |████████████████████████████████| 102kB 11.0MB/s \n",
      "\u001B[?25hCollecting sacrebleu==1.5.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/3b/7f/4fd83db8570288c3899d8e57666c2841403c15659f3d792a3cb8dc1c6689/sacrebleu-1.5.0-py3-none-any.whl (65kB)\n",
      "\u001B[K     |████████████████████████████████| 71kB 7.8MB/s \n",
      "\u001B[?25hCollecting pycountry==20.7.3\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/76/73/6f1a412f14f68c273feea29a6ea9b9f1e268177d32e0e69ad6790d306312/pycountry-20.7.3.tar.gz (10.1MB)\n",
      "\u001B[K     |████████████████████████████████| 10.1MB 53.9MB/s \n",
      "\u001B[?25hCollecting numexpr==2.7.2\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/9c/f4/fa8755c1aa44b431267aa019922f6cc9ec099cef0c6fc0ead0f9a2aa59e5/numexpr-2.7.2-cp37-cp37m-manylinux2010_x86_64.whl (471kB)\n",
      "\u001B[K     |████████████████████████████████| 481kB 45.7MB/s \n",
      "\u001B[?25hCollecting lm_dataformat==0.0.19\n",
      "  Downloading https://files.pythonhosted.org/packages/83/b5/8d10bf5a8082921792bb09c9d591dfd622cf4a16fbb7e283cc921c5ffc50/lm_dataformat-0.0.19-py3-none-any.whl\n",
      "Collecting pytest==6.2.3\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/76/4d/9c00146923da9f1cabd1878209d71b1380d537ec331a1a613e8f4b9d7985/pytest-6.2.3-py3-none-any.whl (280kB)\n",
      "\u001B[K     |████████████████████████████████| 286kB 49.2MB/s \n",
      "\u001B[?25hCollecting pybind11==2.6.2\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/8d/43/7339dbabbc2793718d59703aace4166f53c29ee1c202f6ff5bf8a26c4d91/pybind11-2.6.2-py2.py3-none-any.whl (191kB)\n",
      "\u001B[K     |████████████████████████████████| 194kB 55.4MB/s \n",
      "\u001B[?25hCollecting tqdm-multiprocess==0.0.11\n",
      "  Downloading https://files.pythonhosted.org/packages/25/7e/0d889fc6c84e3df6b69aaafe893fc77f69b3d968ac9ce574d1c62c688050/tqdm_multiprocess-0.0.11-py3-none-any.whl\n",
      "Collecting zstandard==0.15.2\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/5b/56/dc2a85d06e973f2ad96584b0e5b876d063135d449cb040aeaadd22a910f9/zstandard-0.15.2-cp37-cp37m-manylinux2014_x86_64.whl (2.2MB)\n",
      "\u001B[K     |████████████████████████████████| 2.2MB 52.7MB/s \n",
      "\u001B[?25hCollecting jsonlines==2.0.0\n",
      "  Downloading https://files.pythonhosted.org/packages/d4/58/06f430ff7607a2929f80f07bfd820acbc508a4e977542fefcc522cde9dff/jsonlines-2.0.0-py3-none-any.whl\n",
      "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers~=4.4.2->-r mesh-transformer-jax/requirements.txt (line 2)) (4.0.1)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers~=4.4.2->-r mesh-transformer-jax/requirements.txt (line 2)) (20.9)\n",
      "Collecting tokenizers<0.11,>=0.10.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/d4/e2/df3543e8ffdab68f5acc73f613de9c2b155ac47f162e725dcac87c521c11/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB)\n",
      "\u001B[K     |████████████████████████████████| 3.3MB 50.0MB/s \n",
      "\u001B[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers~=4.4.2->-r mesh-transformer-jax/requirements.txt (line 2)) (2019.12.20)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers~=4.4.2->-r mesh-transformer-jax/requirements.txt (line 2)) (3.0.12)\n",
      "Collecting sacremoses\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)\n",
      "\u001B[K     |████████████████████████████████| 901kB 53.3MB/s \n",
      "\u001B[?25hCollecting configparser>=3.8.1\n",
      "  Downloading https://files.pythonhosted.org/packages/fd/01/ff260a18caaf4457eb028c96eeb405c4a230ca06c8ec9c1379f813caa52e/configparser-5.0.2-py3-none-any.whl\n",
      "Requirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb~=0.10.22->-r mesh-transformer-jax/requirements.txt (line 5)) (1.15.0)\n",
      "Collecting docker-pycreds>=0.4.0\n",
      "  Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n",
      "Requirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb~=0.10.22->-r mesh-transformer-jax/requirements.txt (line 5)) (3.12.4)\n",
      "Collecting sentry-sdk>=0.4.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/1c/4a/a54b254f67d8f4052338d54ebe90126f200693440a93ef76d254d581e3ec/sentry_sdk-1.1.0-py2.py3-none-any.whl (131kB)\n",
      "\u001B[K     |████████████████████████████████| 133kB 53.9MB/s \n",
      "\u001B[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb~=0.10.22->-r mesh-transformer-jax/requirements.txt (line 5)) (2.3)\n",
      "Collecting GitPython>=1.0.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/27/da/6f6224fdfc47dab57881fe20c0d1bc3122be290198ba0bf26a953a045d92/GitPython-3.1.17-py3-none-any.whl (166kB)\n",
      "\u001B[K     |████████████████████████████████| 174kB 52.4MB/s \n",
      "\u001B[?25hCollecting shortuuid>=0.5.0\n",
      "  Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n",
      "Collecting pathtools\n",
      "  Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n",
      "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb~=0.10.22->-r mesh-transformer-jax/requirements.txt (line 5)) (3.13)\n",
      "Collecting subprocess32>=3.5.3\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n",
      "\u001B[K     |████████████████████████████████| 102kB 10.1MB/s \n",
      "\u001B[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb~=0.10.22->-r mesh-transformer-jax/requirements.txt (line 5)) (5.4.8)\n",
      "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb~=0.10.22->-r mesh-transformer-jax/requirements.txt (line 5)) (2.8.1)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests~=2.25.1->-r mesh-transformer-jax/requirements.txt (line 7)) (1.24.3)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests~=2.25.1->-r mesh-transformer-jax/requirements.txt (line 7)) (3.0.4)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests~=2.25.1->-r mesh-transformer-jax/requirements.txt (line 7)) (2020.12.5)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests~=2.25.1->-r mesh-transformer-jax/requirements.txt (line 7)) (2.10)\n",
      "Collecting invoke<2.0,>=1.3\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/87/8f/c153d7db091f342da6bc97f7bedd1b2ce2867c4a8b0aab40fbba85a05e33/invoke-1.5.0-py3-none-any.whl (211kB)\n",
      "\u001B[K     |████████████████████████████████| 215kB 52.0MB/s \n",
      "\u001B[?25hCollecting paramiko>=2.4\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/95/19/124e9287b43e6ff3ebb9cdea3e5e8e88475a873c05ccdf8b7e20d2c4201e/paramiko-2.7.2-py2.py3-none-any.whl (206kB)\n",
      "\u001B[K     |████████████████████████████████| 215kB 51.7MB/s \n",
      "\u001B[?25hCollecting pathlib2\n",
      "  Downloading https://files.pythonhosted.org/packages/e9/45/9c82d3666af4ef9f221cbb954e1d77ddbb513faf552aea6df5f37f1a4859/pathlib2-2.3.5-py2.py3-none-any.whl\n",
      "Collecting chex>=0.0.4\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/f5/b9/445eb59ec23249acffc5322c79b07e20b12dbff45b9c1da6cdae9e947685/chex-0.0.7-py3-none-any.whl (52kB)\n",
      "\u001B[K     |████████████████████████████████| 61kB 7.2MB/s \n",
      "\u001B[?25hRequirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax~=0.0.2->-r mesh-transformer-jax/requirements.txt (line 9)) (0.1.66+cuda110)\n",
      "Collecting aiohttp\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/88/c0/5890b4c8b04a79b7360e8fe4490feb0bb3ab179743f199f0e6220cebd568/aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3MB)\n",
      "\u001B[K     |████████████████████████████████| 1.3MB 47.2MB/s \n",
      "\u001B[?25hRequirement already satisfied: prometheus-client>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (0.10.1)\n",
      "Collecting colorful\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/b0/8e/e386e248266952d24d73ed734c2f5513f34d9557032618c8910e605dfaf6/colorful-0.5.4-py2.py3-none-any.whl (201kB)\n",
      "\u001B[K     |████████████████████████████████| 204kB 54.3MB/s \n",
      "\u001B[?25hCollecting gpustat\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/b4/69/d8c849715171aeabd61af7da080fdc60948b5a396d2422f1f4672e43d008/gpustat-0.6.0.tar.gz (78kB)\n",
      "\u001B[K     |████████████████████████████████| 81kB 9.8MB/s \n",
      "\u001B[?25hRequirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (1.0.2)\n",
      "Collecting aiohttp-cors\n",
      "  Downloading https://files.pythonhosted.org/packages/13/e7/e436a0c0eb5127d8b491a9b83ecd2391c6ff7dcd5548dfaec2080a2340fd/aiohttp_cors-0.7.0-py3-none-any.whl\n",
      "Collecting redis>=3.5.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/a7/7c/24fb0511df653cf1a5d938d8f5d19802a88cef255706fdda242ff97e91b7/redis-3.5.3-py2.py3-none-any.whl (72kB)\n",
      "\u001B[K     |████████████████████████████████| 81kB 9.5MB/s \n",
      "\u001B[?25hCollecting colorama\n",
      "  Downloading https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl\n",
      "Collecting aioredis\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/b0/64/1b1612d0a104f21f80eb4c6e1b6075f2e6aba8e228f46f229cfd3fdac859/aioredis-1.3.1-py3-none-any.whl (65kB)\n",
      "\u001B[K     |████████████████████████████████| 71kB 8.6MB/s \n",
      "\u001B[?25hRequirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (2.6.0)\n",
      "Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (1.34.1)\n",
      "Collecting py-spy>=0.2.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/9d/4d/1a9cbe9a0b543e6733cb38afe26451522a9ef8e4897b59e74cc76838f245/py_spy-0.3.7-py2.py3-none-manylinux1_x86_64.whl (3.1MB)\n",
      "\u001B[K     |████████████████████████████████| 3.1MB 50.0MB/s \n",
      "\u001B[?25hCollecting opencensus\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/18/59/12044123133d000f705383ad98579aeb0dd82d66b33a254a21b54bf0d6bb/opencensus-0.7.13-py2.py3-none-any.whl (127kB)\n",
      "\u001B[K     |████████████████████████████████| 133kB 50.4MB/s \n",
      "\u001B[?25hRequirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax~=0.2.12->-r mesh-transformer-jax/requirements.txt (line 13)) (3.3.0)\n",
      "Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from Flask~=1.1.2->-r mesh-transformer-jax/requirements.txt (line 14)) (1.0.1)\n",
      "Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from Flask~=1.1.2->-r mesh-transformer-jax/requirements.txt (line 14)) (1.1.0)\n",
      "Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from Flask~=1.1.2->-r mesh-transformer-jax/requirements.txt (line 14)) (2.11.3)\n",
      "Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (0.36.2)\n",
      "Requirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.12)\n",
      "Requirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.12.1)\n",
      "Collecting gast==0.3.3\n",
      "  Downloading https://files.pythonhosted.org/packages/d6/84/759f5dd23fec8ba71952d97bcc7e2c9d7d63bdc582421f3cd4be845f0c98/gast-0.3.3-py2.py3-none-any.whl\n",
      "Requirement already satisfied: termcolor~=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.1.0)\n",
      "Collecting h5py~=2.10.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/3f/c0/abde58b837e066bca19a3f7332d9d0493521d7dd6b48248451a9e3fe2214/h5py-2.10.0-cp37-cp37m-manylinux1_x86_64.whl (2.9MB)\n",
      "\u001B[K     |████████████████████████████████| 2.9MB 51.5MB/s \n",
      "\u001B[?25hCollecting tensorflow-estimator<2.5.0,>=2.4.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/74/7e/622d9849abf3afb81e482ffc170758742e392ee129ce1540611199a59237/tensorflow_estimator-2.4.0-py2.py3-none-any.whl (462kB)\n",
      "\u001B[K     |████████████████████████████████| 471kB 47.3MB/s \n",
      "\u001B[?25hRequirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (0.2.0)\n",
      "Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.1.2)\n",
      "Requirement already satisfied: tensorboard~=2.4 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (2.5.0)\n",
      "Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.6.3)\n",
      "Requirement already satisfied: google-auth<2.0dev,>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from google-cloud-storage~=1.36.2->-r mesh-transformer-jax/requirements.txt (line 17)) (1.30.0)\n",
      "Collecting google-resumable-media<2.0dev,>=1.2.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/f9/ad/bc80b0b33ccb5e21375ca1440da9dab99596948d5035e2f597fdcffb31f1/google_resumable_media-1.3.0-py2.py3-none-any.whl (75kB)\n",
      "\u001B[K     |████████████████████████████████| 81kB 9.7MB/s \n",
      "\u001B[?25hCollecting google-cloud-core<2.0dev,>=1.4.1\n",
      "  Downloading https://files.pythonhosted.org/packages/ad/fc/6e8c449185cb8862af353c1164100ff75e32d55ba1de3baf9eaa01b7d2a9/google_cloud_core-1.6.0-py2.py3-none-any.whl\n",
      "Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from ftfy->-r mesh-transformer-jax/requirements.txt (line 20)) (0.2.5)\n",
      "Collecting typed-ast>=1.4.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/65/b3/573d2f1fecbbe8f82a8d08172e938c247f99abe1be3bef3da2efaa3810bf/typed_ast-1.4.3-cp37-cp37m-manylinux1_x86_64.whl (743kB)\n",
      "\u001B[K     |████████████████████████████████| 747kB 52.5MB/s \n",
      "\u001B[?25hRequirement already satisfied: appdirs in /usr/local/lib/python3.7/dist-packages (from black==20.8b1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.4.4)\n",
      "Collecting mypy-extensions>=0.4.3\n",
      "  Downloading https://files.pythonhosted.org/packages/5c/eb/975c7c080f3223a5cdaff09612f3a5221e4ba534f7039db34c35d95fa6a5/mypy_extensions-0.4.3-py2.py3-none-any.whl\n",
      "Collecting pathspec<1,>=0.6\n",
      "  Downloading https://files.pythonhosted.org/packages/29/29/a465741a3d97ea3c17d21eaad4c64205428bde56742360876c4391f930d4/pathspec-0.8.1-py2.py3-none-any.whl\n",
      "Requirement already satisfied: toml>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from black==20.8b1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (0.10.2)\n",
      "Collecting rehash\n",
      "  Downloading https://files.pythonhosted.org/packages/c9/e4/30db193232d9e9c8e123764d84f0807535677548833ca251556ad6134c24/rehash-1.0.0-py2.py3-none-any.whl\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets>=1.2.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.1.5)\n",
      "Requirement already satisfied: pyarrow<4.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets>=1.2.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (3.0.0)\n",
      "Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets>=1.2.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (0.3.3)\n",
      "Collecting xxhash\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/7d/4f/0a862cad26aa2ed7a7cd87178cbbfa824fc1383e472d63596a0d018374e7/xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243kB)\n",
      "\u001B[K     |████████████████████████████████| 245kB 51.5MB/s \n",
      "\u001B[?25hCollecting fsspec\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/8e/d2/d05466997f7751a2c06a7a416b7d1f131d765f7916698d3fdcb3a4d037e5/fsspec-2021.6.0-py3-none-any.whl (114kB)\n",
      "\u001B[K     |████████████████████████████████| 122kB 50.9MB/s \n",
      "\u001B[?25hCollecting huggingface-hub<0.1.0\n",
      "  Downloading https://files.pythonhosted.org/packages/3c/e3/fb7b6aefaf0fc7b792cebbbd590b1895c022ab0ff27f389e1019c6f2e68a/huggingface_hub-0.0.10-py3-none-any.whl\n",
      "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets>=1.2.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (0.70.11.1)\n",
      "Collecting threadpoolctl>=2.0.0\n",
      "  Downloading https://files.pythonhosted.org/packages/f7/12/ec3f2e203afa394a149911729357aa48affc59c20e2c1c8297a60f33f133/threadpoolctl-2.1.0-py3-none-any.whl\n",
      "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.24.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.0.1)\n",
      "Requirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.24.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.4.1)\n",
      "Collecting tcolorpy<1,>=0.0.5\n",
      "  Downloading https://files.pythonhosted.org/packages/8e/d5/2b53921011bef9a5a4d64f3265a40b5bdfef6e0eda937d57a2dd7a66f89c/tcolorpy-0.0.9-py3-none-any.whl\n",
      "Collecting msgfy<1,>=0.1.0\n",
      "  Downloading https://files.pythonhosted.org/packages/48/52/c4441871514276e7c4cb51c122e663b5ef19dc20030f6ab7723071118464/msgfy-0.1.0-py3-none-any.whl\n",
      "Collecting DataProperty<2,>=0.50.0\n",
      "  Downloading https://files.pythonhosted.org/packages/b9/5f/c773c362fcba227d6a4021225cb0213b51849c7dc9c93004d34d9004078b/DataProperty-0.50.1-py3-none-any.whl\n",
      "Collecting typepy[datetime]<2,>=1.1.1\n",
      "  Downloading https://files.pythonhosted.org/packages/60/3a/1239e59924250d9c2dd1d5b84748da82d15aaa241b3ceeffa08aa5eba589/typepy-1.1.5-py3-none-any.whl\n",
      "Collecting mbstrdecoder<2,>=1.0.0\n",
      "  Downloading https://files.pythonhosted.org/packages/e8/f6/0e6bb50c3c6380a4982c87d80e70b2f6e366523a57a0c58594aea472206d/mbstrdecoder-1.0.1-py3-none-any.whl\n",
      "Collecting pathvalidate<3,>=2.3.0\n",
      "  Downloading https://files.pythonhosted.org/packages/87/55/7d63b78986f1f8764180b84ee3e8a47c583ec059d32c98d8fba7fc0dc1ae/pathvalidate-2.4.1-py3-none-any.whl\n",
      "Collecting tabledata<2,>=1.1.3\n",
      "  Downloading https://files.pythonhosted.org/packages/85/93/4c695da7e6589e1e4b513c02d5b562dcc5afacb8a2f6cac8eb2ac2e88833/tabledata-1.1.4-py3-none-any.whl\n",
      "Collecting portalocker\n",
      "  Downloading https://files.pythonhosted.org/packages/68/33/cb524f4de298509927b90aa5ee34767b9a2b93e663cf354b2a3efa2b4acd/portalocker-2.3.0-py2.py3-none-any.whl\n",
      "Collecting ujson\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/17/4e/50e8e4cf5f00b537095711c2c86ac4d7191aed2b4fffd5a19f06898f6929/ujson-4.0.2-cp37-cp37m-manylinux1_x86_64.whl (179kB)\n",
      "\u001B[K     |████████████████████████████████| 184kB 55.9MB/s \n",
      "\u001B[?25hRequirement already satisfied: iniconfig in /usr/local/lib/python3.7/dist-packages (from pytest==6.2.3->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.1.1)\n",
      "Collecting pluggy<1.0.0a1,>=0.12\n",
      "  Downloading https://files.pythonhosted.org/packages/a0/28/85c7aa31b80d150b772fbe4a229487bc6644da9ccb7e427dd8cc60cb8a62/pluggy-0.13.1-py2.py3-none-any.whl\n",
      "Requirement already satisfied: py>=1.8.2 in /usr/local/lib/python3.7/dist-packages (from pytest==6.2.3->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (1.10.0)\n",
      "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.7/dist-packages (from pytest==6.2.3->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (21.2.0)\n",
      "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers~=4.4.2->-r mesh-transformer-jax/requirements.txt (line 2)) (3.4.1)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers~=4.4.2->-r mesh-transformer-jax/requirements.txt (line 2)) (2.4.7)\n",
      "Collecting gitdb<5,>=4.0.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/ea/e8/f414d1a4f0bbc668ed441f74f44c116d9816833a48bf81d22b697090dba8/gitdb-4.0.7-py3-none-any.whl (63kB)\n",
      "\u001B[K     |████████████████████████████████| 71kB 8.1MB/s \n",
      "\u001B[?25hCollecting cryptography>=2.5\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/b2/26/7af637e6a7e87258b963f1731c5982fb31cd507f0d90d91836e446955d02/cryptography-3.4.7-cp36-abi3-manylinux2014_x86_64.whl (3.2MB)\n",
      "\u001B[K     |████████████████████████████████| 3.2MB 49.6MB/s \n",
      "\u001B[?25hCollecting pynacl>=1.0.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/9d/57/2f5e6226a674b2bcb6db531e8b383079b678df5b10cdaa610d6cf20d77ba/PyNaCl-1.4.0-cp35-abi3-manylinux1_x86_64.whl (961kB)\n",
      "\u001B[K     |████████████████████████████████| 962kB 47.6MB/s \n",
      "\u001B[?25hCollecting bcrypt>=3.1.3\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/26/70/6d218afbe4c73538053c1016dd631e8f25fffc10cd01f5c272d7acf3c03d/bcrypt-3.2.0-cp36-abi3-manylinux2010_x86_64.whl (63kB)\n",
      "\u001B[K     |████████████████████████████████| 71kB 8.2MB/s \n",
      "\u001B[?25hRequirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax~=0.0.2->-r mesh-transformer-jax/requirements.txt (line 9)) (0.1.6)\n",
      "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax~=0.0.2->-r mesh-transformer-jax/requirements.txt (line 9)) (0.11.1)\n",
      "Collecting yarl<2.0,>=1.0\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/f1/62/046834c5fc998c88ab2ef722f5d42122230a632212c8afa76418324f53ff/yarl-1.6.3-cp37-cp37m-manylinux2014_x86_64.whl (294kB)\n",
      "\u001B[K     |████████████████████████████████| 296kB 42.9MB/s \n",
      "\u001B[?25hCollecting async-timeout<4.0,>=3.0\n",
      "  Downloading https://files.pythonhosted.org/packages/e1/1e/5a4441be21b0726c4464f3f23c8b19628372f606755a9d2e46c187e65ec4/async_timeout-3.0.1-py3-none-any.whl\n",
      "Collecting multidict<7.0,>=4.5\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/7c/a6/4123b8165acbe773d1a8dc8e3f0d1edea16d29f7de018eda769abb56bd30/multidict-5.1.0-cp37-cp37m-manylinux2014_x86_64.whl (142kB)\n",
      "\u001B[K     |████████████████████████████████| 143kB 49.7MB/s \n",
      "\u001B[?25hRequirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.7/dist-packages (from gpustat->ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (7.352.0)\n",
      "Collecting blessings>=1.6\n",
      "  Downloading https://files.pythonhosted.org/packages/03/74/489f85a78247609c6b4f13733cbf3ba0d864b11aa565617b645d6fdf2a4a/blessings-1.7-py3-none-any.whl\n",
      "Collecting hiredis\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/ed/33/290cea35b09c80b4634773ad5572a8030a87b5d39736719f698f521d2a13/hiredis-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl (85kB)\n",
      "\u001B[K     |████████████████████████████████| 92kB 9.7MB/s \n",
      "\u001B[?25hCollecting opencensus-context==0.1.2\n",
      "  Downloading https://files.pythonhosted.org/packages/f1/33/990f1bd9e7ee770fc8d3c154fc24743a96f16a0e49e14e1b7540cc2fdd93/opencensus_context-0.1.2-py2.py3-none-any.whl\n",
      "Requirement already satisfied: google-api-core<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from opencensus->ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (1.26.3)\n",
      "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->Flask~=1.1.2->-r mesh-transformer-jax/requirements.txt (line 14)) (2.0.1)\n",
      "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.8.0)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (3.3.4)\n",
      "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (0.6.1)\n",
      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.4->tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (0.4.4)\n",
      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.11.0->google-cloud-storage~=1.36.2->-r mesh-transformer-jax/requirements.txt (line 17)) (4.2.2)\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.11.0->google-cloud-storage~=1.36.2->-r mesh-transformer-jax/requirements.txt (line 17)) (0.2.8)\n",
      "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3.6\" in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.11.0->google-cloud-storage~=1.36.2->-r mesh-transformer-jax/requirements.txt (line 17)) (4.7.2)\n",
      "Collecting google-crc32c<2.0dev,>=1.0; python_version >= \"3.5\"\n",
      "  Downloading https://files.pythonhosted.org/packages/fc/ae/b6efa1019e18c6c791f0f5cd93b2ff40f8f06696dbf04db39ec0f5591b1e/google_crc32c-1.1.2-cp37-cp37m-manylinux2014_x86_64.whl\n",
      "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets>=1.2.1->lm-eval-harness==0.0.1->-r mesh-transformer-jax/requirements.txt (line 11)) (2018.9)\n",
      "Collecting smmap<5,>=3.0.1\n",
      "  Downloading https://files.pythonhosted.org/packages/68/ee/d540eb5e5996eb81c26ceffac6ee49041d473bc5125f2aa995cf51ec1cf1/smmap-4.0.0-py2.py3-none-any.whl\n",
      "Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.7/dist-packages (from cryptography>=2.5->paramiko>=2.4->fabric~=2.6.0->-r mesh-transformer-jax/requirements.txt (line 8)) (1.14.5)\n",
      "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2.0.0,>=1.0.0->opencensus->ray~=1.2.0->-r mesh-transformer-jax/requirements.txt (line 12)) (1.53.0)\n",
      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (1.3.0)\n",
      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2.0dev,>=1.11.0->google-cloud-storage~=1.36.2->-r mesh-transformer-jax/requirements.txt (line 17)) (0.4.8)\n",
      "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.12->cryptography>=2.5->paramiko>=2.4->fabric~=2.6.0->-r mesh-transformer-jax/requirements.txt (line 8)) (2.20)\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow-cpu~=2.4.1->-r mesh-transformer-jax/requirements.txt (line 16)) (3.1.0)\n",
      "Building wheels for collected packages: black\n",
      "  Building wheel for black (PEP 517) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for black: filename=black-20.8b1-cp37-none-any.whl size=124195 sha256=820b5c710b946b935fabc8b39f608ff2fab3519bd888db1b69710011b7aed5e7\n",
      "  Stored in directory: /root/.cache/pip/wheels/6e/10/b5/edf7359c2edd0305cce7e3f96e07daf7ce55dceac9d3ce3373\n",
      "Successfully built black\n",
      "Building wheels for collected packages: func-timeout, ftfy, dm-haiku, lm-eval-harness, sqlitedict, pycountry, pathtools, subprocess32, gpustat\n",
      "  Building wheel for func-timeout (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for func-timeout: filename=func_timeout-4.3.5-cp37-none-any.whl size=15097 sha256=4cac14d160ec61e4a096e609113e6396cc5864a3342fc4d8c2c90e6b5ca9e48d\n",
      "  Stored in directory: /root/.cache/pip/wheels/46/7c/4f/24f1d2d5bbff92219debe7ea19af84f76ddeb90dd4ec544f26\n",
      "  Building wheel for ftfy (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for ftfy: filename=ftfy-6.0.3-cp37-none-any.whl size=41935 sha256=16a45289c44008e2f575fa886e24361690e3238aae812840d9027fde52b19e36\n",
      "  Stored in directory: /root/.cache/pip/wheels/99/2c/e6/109c8a28fef7a443f67ba58df21fe1d0067ac3322e75e6b0b7\n",
      "  Building wheel for dm-haiku (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for dm-haiku: filename=dm_haiku-0.0.5.dev0-cp37-none-any.whl size=558211 sha256=55ddf8b30b5d51cc31684835a0f8283bd545c3cd90310df9d0bc03b657ca487d\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-mkgd1im2/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499\n",
      "  Building wheel for lm-eval-harness (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for lm-eval-harness: filename=lm_eval_harness-0.0.1-cp37-none-any.whl size=96331 sha256=e2e144196bad68df5d022a1e16676fd2f0cc18e865973ec09b90f4469880cbfa\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-mkgd1im2/wheels/a8/db/b4/32ca8efd6b64f9187fefbabd636d7c94cd2150a657262ee22a\n",
      "  Building wheel for sqlitedict (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for sqlitedict: filename=sqlitedict-1.6.0-cp37-none-any.whl size=14714 sha256=af52aa57cead6fde363f9f52a4ca12b4014438bada7e37f299f49f97445e853b\n",
      "  Stored in directory: /root/.cache/pip/wheels/bd/57/d3/907c3ee02d35e66f674ad0106e61f06eeeb98f6ee66a6cc3fe\n",
      "  Building wheel for pycountry (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for pycountry: filename=pycountry-20.7.3-py2.py3-none-any.whl size=10746883 sha256=fff18e590a27afa318d4c31ee7098a15916a59b00bb5ca7040e5c500b1336d7c\n",
      "  Stored in directory: /root/.cache/pip/wheels/33/4e/a6/be297e6b83567e537bed9df4a93f8590ec01c1acfbcd405348\n",
      "  Building wheel for pathtools (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for pathtools: filename=pathtools-0.1.2-cp37-none-any.whl size=8807 sha256=45a2c0fca698059654a7c1feb9e1f9fd74247cbd15ea3a00fae83f1b22175fc7\n",
      "  Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n",
      "  Building wheel for subprocess32 (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for subprocess32: filename=subprocess32-3.5.4-cp37-none-any.whl size=6502 sha256=d3f4fc7a51cca5a8b462fc6826afef60c617fcf0e015b2ed027d320dea68741c\n",
      "  Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n",
      "  Building wheel for gpustat (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for gpustat: filename=gpustat-0.6.0-cp37-none-any.whl size=12621 sha256=3f9b348ee23b2452005787b00eb878cfd1be319e76c53f482811b6d75475761e\n",
      "  Stored in directory: /root/.cache/pip/wheels/48/b4/d5/fb5b7f1d040f2ff20687e3bad6867d63155dbde5a7c10f4293\n",
      "Successfully built func-timeout ftfy dm-haiku lm-eval-harness sqlitedict pycountry pathtools subprocess32 gpustat\n",
      "\u001B[31mERROR: tensorflow 2.5.0 has requirement gast==0.4.0, but you'll have gast 0.3.3 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: tensorflow 2.5.0 has requirement h5py~=3.1.0, but you'll have h5py 2.10.0 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: tensorflow 2.5.0 has requirement tensorflow-estimator<2.6.0,>=2.5.0rc0, but you'll have tensorflow-estimator 2.4.0 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: google-cloud-bigquery 1.21.0 has requirement google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1, but you'll have google-resumable-media 1.3.0 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: tensorflow-cpu 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.\u001B[0m\n",
      "\u001B[31mERROR: black 20.8b1 has requirement regex>=2020.1.8, but you'll have regex 2019.12.20 which is incompatible.\u001B[0m\n",
      "Installing collected packages: tokenizers, tqdm, sacremoses, requests, transformers, setuptools, configparser, docker-pycreds, sentry-sdk, smmap, gitdb, GitPython, shortuuid, pathtools, subprocess32, wandb, einops, invoke, cryptography, pynacl, bcrypt, paramiko, pathlib2, fabric, chex, optax, multidict, yarl, async-timeout, aiohttp, colorful, blessings, gpustat, aiohttp-cors, redis, colorama, hiredis, aioredis, py-spy, opencensus-context, opencensus, ray, gast, h5py, tensorflow-estimator, tensorflow-cpu, google-crc32c, google-resumable-media, google-cloud-core, google-cloud-storage, func-timeout, ftfy, jmp, dm-haiku, typed-ast, mypy-extensions, pathspec, black, rehash, best-download, xxhash, fsspec, huggingface-hub, datasets, threadpoolctl, scikit-learn, sqlitedict, tcolorpy, msgfy, mbstrdecoder, typepy, DataProperty, pathvalidate, tabledata, pytablewriter, portalocker, sacrebleu, pycountry, numexpr, zstandard, jsonlines, ujson, lm-dataformat, pluggy, pytest, pybind11, tqdm-multiprocess, lm-eval-harness\n",
      "  Found existing installation: tqdm 4.41.1\n",
      "    Uninstalling tqdm-4.41.1:\n",
      "      Successfully uninstalled tqdm-4.41.1\n",
      "  Found existing installation: requests 2.23.0\n",
      "    Uninstalling requests-2.23.0:\n",
      "      Successfully uninstalled requests-2.23.0\n",
      "  Found existing installation: setuptools 57.0.0\n",
      "    Uninstalling setuptools-57.0.0:\n",
      "      Successfully uninstalled setuptools-57.0.0\n",
      "  Found existing installation: gast 0.4.0\n",
      "    Uninstalling gast-0.4.0:\n",
      "      Successfully uninstalled gast-0.4.0\n",
      "  Found existing installation: h5py 3.1.0\n",
      "    Uninstalling h5py-3.1.0:\n",
      "      Successfully uninstalled h5py-3.1.0\n",
      "  Found existing installation: tensorflow-estimator 2.5.0\n",
      "    Uninstalling tensorflow-estimator-2.5.0:\n",
      "      Successfully uninstalled tensorflow-estimator-2.5.0\n",
      "  Found existing installation: google-resumable-media 0.4.1\n",
      "    Uninstalling google-resumable-media-0.4.1:\n",
      "      Successfully uninstalled google-resumable-media-0.4.1\n",
      "  Found existing installation: google-cloud-core 1.0.3\n",
      "    Uninstalling google-cloud-core-1.0.3:\n",
      "      Successfully uninstalled google-cloud-core-1.0.3\n",
      "  Found existing installation: google-cloud-storage 1.18.1\n",
      "    Uninstalling google-cloud-storage-1.18.1:\n",
      "      Successfully uninstalled google-cloud-storage-1.18.1\n",
      "  Found existing installation: scikit-learn 0.22.2.post1\n",
      "    Uninstalling scikit-learn-0.22.2.post1:\n",
      "      Successfully uninstalled scikit-learn-0.22.2.post1\n",
      "  Found existing installation: numexpr 2.7.3\n",
      "    Uninstalling numexpr-2.7.3:\n",
      "      Successfully uninstalled numexpr-2.7.3\n",
      "  Found existing installation: pluggy 0.7.1\n",
      "    Uninstalling pluggy-0.7.1:\n",
      "      Successfully uninstalled pluggy-0.7.1\n",
      "  Found existing installation: pytest 3.6.4\n",
      "    Uninstalling pytest-3.6.4:\n",
      "      Successfully uninstalled pytest-3.6.4\n",
      "Successfully installed DataProperty-0.50.1 GitPython-3.1.17 aiohttp-3.7.4.post0 aiohttp-cors-0.7.0 aioredis-1.3.1 async-timeout-3.0.1 bcrypt-3.2.0 best-download-0.0.6 black-20.8b1 blessings-1.7 chex-0.0.7 colorama-0.4.4 colorful-0.5.4 configparser-5.0.2 cryptography-3.4.7 datasets-1.8.0 dm-haiku-0.0.5.dev0 docker-pycreds-0.4.0 einops-0.3.0 fabric-2.6.0 fsspec-2021.6.0 ftfy-6.0.3 func-timeout-4.3.5 gast-0.3.3 gitdb-4.0.7 google-cloud-core-1.6.0 google-cloud-storage-1.36.2 google-crc32c-1.1.2 google-resumable-media-1.3.0 gpustat-0.6.0 h5py-2.10.0 hiredis-2.0.0 huggingface-hub-0.0.10 invoke-1.5.0 jmp-0.0.2 jsonlines-2.0.0 lm-dataformat-0.0.19 lm-eval-harness-0.0.1 mbstrdecoder-1.0.1 msgfy-0.1.0 multidict-5.1.0 mypy-extensions-0.4.3 numexpr-2.7.2 opencensus-0.7.13 opencensus-context-0.1.2 optax-0.0.6 paramiko-2.7.2 pathlib2-2.3.5 pathspec-0.8.1 pathtools-0.1.2 pathvalidate-2.4.1 pluggy-0.13.1 portalocker-2.3.0 py-spy-0.3.7 pybind11-2.6.2 pycountry-20.7.3 pynacl-1.4.0 pytablewriter-0.58.0 pytest-6.2.3 ray-1.2.0 redis-3.5.3 rehash-1.0.0 requests-2.25.1 sacrebleu-1.5.0 sacremoses-0.0.45 scikit-learn-0.24.2 sentry-sdk-1.1.0 setuptools-51.3.3 shortuuid-1.0.1 smmap-4.0.0 sqlitedict-1.6.0 subprocess32-3.5.4 tabledata-1.1.4 tcolorpy-0.0.9 tensorflow-cpu-2.4.1 tensorflow-estimator-2.4.0 threadpoolctl-2.1.0 tokenizers-0.10.3 tqdm-4.45.0 tqdm-multiprocess-0.0.11 transformers-4.4.2 typed-ast-1.4.3 typepy-1.1.5 ujson-4.0.2 wandb-0.10.31 xxhash-2.0.2 yarl-1.6.3 zstandard-0.15.2\n"
     ],
     "name": "stdout"
    },
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.colab-display-data+json": {
       "pip_warning": {
        "packages": [
         "google",
         "pkg_resources"
        ]
       }
      }
     },
     "metadata": {
      "tags": []
     }
    },
    {
     "output_type": "stream",
     "text": [
      "Processing ./mesh-transformer-jax\n",
      "Collecting jax==0.2.12\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/9a/67/d1a9c94104c559b49bbcb72e9efc33859e982d741ea4902d2a00e66e09d9/jax-0.2.12.tar.gz (590kB)\n",
      "\u001B[K     |████████████████████████████████| 593kB 2.2MB/s \n",
      "\u001B[?25hRequirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from jax==0.2.12) (1.19.5)\n",
      "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax==0.2.12) (0.12.0)\n",
      "Requirement already satisfied: opt_einsum in /usr/local/lib/python3.7/dist-packages (from jax==0.2.12) (3.3.0)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax==0.2.12) (1.15.0)\n",
      "Building wheels for collected packages: jax, mesh-transformer\n",
      "  Building wheel for jax (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for jax: filename=jax-0.2.12-cp37-none-any.whl size=682484 sha256=4636fb45fcbebf88d4dd007e59b03f69fb647920499329b29f883e6ee8a9eccb\n",
      "  Stored in directory: /root/.cache/pip/wheels/cf/00/88/75c2043dff473f58e892c7e6adfd2c44ccefb6111fcc021e5b\n",
      "  Building wheel for mesh-transformer (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for mesh-transformer: filename=mesh_transformer-0.0.0-cp37-none-any.whl size=15682 sha256=385ec067463f1689fdffb1b0fd3f98ca40aa878ce5ff2187bf743675ba91190e\n",
      "  Stored in directory: /root/.cache/pip/wheels/de/a9/d2/2be3e25299342b60fca7965d4e416264ff8b6d8a7e8def76da\n",
      "Successfully built jax mesh-transformer\n",
      "Installing collected packages: jax, mesh-transformer\n",
      "  Found existing installation: jax 0.2.13\n",
      "    Uninstalling jax-0.2.13:\n",
      "      Successfully uninstalled jax-0.2.13\n",
      "Successfully installed jax-0.2.12 mesh-transformer-0.0.0\n"
     ],
     "name": "stdout"
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aO1UXepF-0Uq"
   },
   "source": [
    "## Set Up Model\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ex0qJgaueZtJ"
   },
   "source": [
    "import os\n",
    "import requests\n",
    "from jax.config import config\n",
    "\n",
    "# Request a specific TPU driver version from the Colab TPU host (port 8475).\n",
    "colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]\n",
    "url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'\n",
    "requests.post(url)\n",
    "\n",
    "# The following is required to use TPU Driver as JAX's backend.\n",
    "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
    "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']"
   ],
   "execution_count": 4,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NIgUVdFLe4A8"
   },
   "source": [
    "Sometimes the next cell fails for no obvious reason; if it does, just run it again ¯\\\\\\_(ツ)\\_/¯"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "-A5IGYSaeze3"
   },
   "source": [
    "import time\n",
    "\n",
    "import jax\n",
    "from jax.experimental import maps\n",
    "import numpy as np\n",
    "import optax\n",
    "import transformers\n",
    "\n",
    "from mesh_transformer.checkpoint import read_ckpt_lowmem\n",
    "from mesh_transformer.sampling import nucleaus_sample  # (sic) the library spells it this way\n",
    "from mesh_transformer.transformer_shard import CausalTransformer"
   ],
   "execution_count": 5,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "QAgKq-X2kmba",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 167,
     "referenced_widgets": [
      "5b8a31b3b4034116af0a81d897c3123b",
      "835ee2070c05443c9fa3ae6fece3fb44",
      "9e6646f7066341ef9619f5f7e6f7c8b8",
      "4f803e3d56b74e18997b216bf9d65337",
      "f9fa7c1bc9f04df09e487b16761d6b8f",
      "39aed05f694a495abd03aa00ff71efb3",
      "61b0857802184b51abf77039fad681d7",
      "0fb689df0a8b42c2b87e7f11b818b88a",
      "8b76924bdac749d5b4908678fd0a7497",
      "b03bd083afc84f52834d4a124e57c442",
      "3ba9fe91731f4ff380cda2183ad72c4d",
      "4e01b9cceffb4d178e97f1d1b36d62b0",
      "b3ac7912f51746afae20616269dc69f0",
      "fb1263f5a3da42e6acab0e9a66551b32",
      "3426927332884df4b1ec5544ed479c45",
      "debf4a0a609a4cffa5cd686a0a23f6e2",
      "7197de8a39ab4e4186eabd061c730cf5",
      "c31661c9a1ff4bd681cc10d70fd287c0",
      "bb11a4fbf16f488082aa2bac7443fd62",
      "5360630d6b0541dbaafa7ff91610bd84",
      "f3883c133c62438db114c22cfc94e351",
      "74672c43909444c592955afd0d727a28",
      "67f7c3bc59384dc7812d7130c7cde450",
      "bb7410ff9cb4484da73302993427b90a"
     ]
    },
    "outputId": "041117df-a315-4b95-9caa-26cb870ff3df"
   },
   "source": [
    "params = {\n",
    "  \"layers\": 28,\n",
    "  \"d_model\": 4096,\n",
    "  \"n_heads\": 16,\n",
    "  \"n_vocab\": 50400,\n",
    "  \"norm\": \"layernorm\",\n",
    "  \"pe\": \"rotary\",\n",
    "  \"pe_rotary_dims\": 64,\n",
    "\n",
    "  \"seq\": 2048,\n",
    "  \"cores_per_replica\": 8,\n",
    "  \"per_replica_batch\": 1,\n",
    "}\n",
    "\n",
    "per_replica_batch = params[\"per_replica_batch\"]\n",
    "cores_per_replica = params[\"cores_per_replica\"]\n",
    "seq = params[\"seq\"]\n",
    "\n",
    "params[\"sampler\"] = nucleaus_sample\n",
    "\n",
    "# Inference only: a no-op optimizer (scales every update by 0), so no real optimizer state has to be allocated.\n",
    "params[\"optimizer\"] = optax.scale(0)\n",
    "\n",
    "# Device mesh: rows are data-parallel replicas ('dp'), columns are the cores one model copy is sharded across ('mp').\n",
    "mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)\n",
    "devices = np.array(jax.devices()).reshape(mesh_shape)\n",
    "\n",
    "maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))\n",
    "\n",
    "# GPT-J uses the GPT-2 BPE tokenizer (n_vocab = 50400 is padded beyond its 50257 tokens).\n",
    "tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')"
   ],
   "execution_count": 6,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b8a31b3b4034116af0a81d897c3123b",
       "version_minor": 0,
       "version_major": 2
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…"
      ]
     },
     "metadata": {
      "tags": []
     }
    },
    {
     "output_type": "stream",
     "text": [
      "\n"
     ],
     "name": "stdout"
    },
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b76924bdac749d5b4908678fd0a7497",
       "version_minor": 0,
       "version_major": 2
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {
      "tags": []
     }
    },
    {
     "output_type": "stream",
     "text": [
      "\n"
     ],
     "name": "stdout"
    },
    {
     "output_type": "display_data",
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7197de8a39ab4e4186eabd061c730cf5",
       "version_minor": 0,
       "version_major": 2
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355256.0, style=ProgressStyle(descript…"
      ]
     },
     "metadata": {
      "tags": []
     }
    },
    {
     "output_type": "stream",
     "text": [
      "\n"
     ],
     "name": "stdout"
    }
   ]
  },
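  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With these settings each attention head has `d_model / n_heads = 4096 / 16 = 256` dimensions, and `pe_rotary_dims: 64` indicates that rotary position embeddings cover only 64 of them. The NumPy sketch below illustrates partial rotary embeddings for a single head under our own assumptions (rotation of interleaved pairs in the first `rotary_dims` dimensions, base 10000). It is not mesh_transformer's implementation, and `partial_rotary` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def partial_rotary(x, rotary_dims, base=10000.0):\n",
    "    # x: (seq_len, head_dim) activations for a single attention head\n",
    "    seq_len, head_dim = x.shape\n",
    "    assert rotary_dims % 2 == 0 and rotary_dims <= head_dim\n",
    "    # one frequency per pair of rotary dimensions\n",
    "    inv_freq = 1.0 / base ** (np.arange(0, rotary_dims, 2) / rotary_dims)\n",
    "    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, rotary_dims // 2)\n",
    "    cos, sin = np.cos(angles), np.sin(angles)\n",
    "    rot = x[:, :rotary_dims]\n",
    "    even, odd = rot[:, 0::2], rot[:, 1::2]\n",
    "    # rotate each (even, odd) pair by its position-dependent angle\n",
    "    rotated = np.empty_like(rot)\n",
    "    rotated[:, 0::2] = even * cos - odd * sin\n",
    "    rotated[:, 1::2] = even * sin + odd * cos\n",
    "    # dimensions past rotary_dims pass through unchanged\n",
    "    return np.concatenate([rotated, x[:, rotary_dims:]], axis=-1)\n",
    "\n",
    "head_dim = 4096 // 16  # d_model // n_heads = 256\n",
    "x = np.random.default_rng(0).normal(size=(8, head_dim))\n",
    "y = partial_rotary(x, rotary_dims=64)\n",
    "```\n",
    "\n",
    "Position 0 is left unchanged (all of its angles are zero), each rotated pair keeps its length, and the remaining 192 dimensions of the head receive no positional rotation at all."
   ]
  },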
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yFgRkUgfiNdA"
   },
   "source": [
    "Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lwNETD2Uk8nu",
    "outputId": "a659285b-5b46-4ddf-d65c-fbdad0594b86"
   },
   "source": [
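    "# Sequences generated per call: per_replica_batch for each data-parallel replica.\n",
    "total_batch = per_replica_batch * jax.device_count() // cores_per_replica\n",
    "\n",
    "network = CausalTransformer(params)\n",
    "\n",
    "# Read the checkpoint in step_383500/; devices.shape[1] is the number of model-parallel shards.\n",
    "network.state = read_ckpt_lowmem(network.state, \"step_383500/\", devices.shape[1])\n",
    "\n",
    "# Place the loaded parameters onto the TPU cores.\n",
    "network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))"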
   ],
   "execution_count": 7,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py:412: UserWarning: xmap is an experimental feature and probably has bugs!\n",
      "  warn(\"xmap is an experimental feature and probably has bugs!\")\n"
     ],
     "name": "stderr"
    },
    {
     "output_type": "stream",
     "text": [
      "key shape (8, 2)\n",
      "in shape (1, 2048)\n",
      "dp 1\n",
      "mp 8\n",
      "read from gcs in 22.633s\n"
     ],
     "name": "stdout"
    }
   ]
  },
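  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `dp 1` and `mp 8` lines in the log above reflect the mesh built in the setup cell, `mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)`, and `total_batch` grows with the number of data-parallel rows. The plain-Python sketch below reproduces that arithmetic. `mesh_and_batch` is a hypothetical helper, the divisibility check is an addition (in the notebook a mismatch would instead surface as a `reshape` error), and device counts other than the 8 cores seen in this run are hypothetical.\n",
    "\n",
    "```python\n",
    "def mesh_and_batch(device_count, cores_per_replica=8, per_replica_batch=1):\n",
    "    # Mirrors the notebook's formulas: rows are data-parallel replicas ('dp'),\n",
    "    # columns are the cores that shard one copy of the model ('mp').\n",
    "    if device_count % cores_per_replica != 0:\n",
    "        raise ValueError('device_count must be a multiple of cores_per_replica')\n",
    "    dp = device_count // cores_per_replica\n",
    "    mesh_shape = (dp, cores_per_replica)\n",
    "    total_batch = per_replica_batch * device_count // cores_per_replica\n",
    "    return mesh_shape, total_batch\n",
    "\n",
    "for n in (8, 16, 32):\n",
    "    print(n, mesh_and_batch(n, per_replica_batch=2))\n",
    "# prints: 8 ((1, 8), 2) / 16 ((2, 8), 4) / 32 ((4, 8), 8)\n",
    "```\n",
    "\n",
    "Because `total_batch` is computed when the network is built, changing `per_replica_batch` only takes effect after re-running the setup and loading cells."
   ]
  },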
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A-eT7Sw6if4J"
   },
   "source": [
    "## Run Model\n",
    "\n",
    "Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation, but after that each sample should take roughly 10-15 seconds.\n",
    "\n",
    "Feel free to experiment with the sampling parameters (`top_p` and `temp`) and with the generation length (`gen_len`; changing it triggers a recompile).\n",
    "\n",
    "You can also change settings in the previous cells, such as `per_replica_batch`, to control how many generations run in parallel (re-run the cells that build and load the network afterwards, since the batch size is fixed there). A larger batch has higher latency but higher throughput, measured in tokens generated per second, which is useful for things like best-of-n cherry-picking.\n",
    "\n",
    "*Tip for best results: make sure your prompt has no trailing spaces. The BPE tokenizer used during training attaches a space to the start of the following word, so a trailing space becomes an unusual standalone token that tends to confuse the model.*"
   ]
  },
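  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about `top_p` and `temp`, here is a minimal NumPy sketch of temperature scaling followed by nucleus (top-p) sampling over one vector of logits. It is a generic illustration, not the code in `mesh_transformer.sampling.nucleaus_sample`; `sample_top_p` is a hypothetical name, and the sketch assumes `temp > 0`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_top_p(logits, top_p=0.9, temp=1.0, rng=None):\n",
    "    # assumes temp > 0; a lower temperature sharpens the distribution\n",
    "    rng = np.random.default_rng() if rng is None else rng\n",
    "    scaled = np.asarray(logits, dtype=np.float64) / temp\n",
    "    probs = np.exp(scaled - scaled.max())\n",
    "    probs /= probs.sum()\n",
    "    # sort tokens by probability, highest first\n",
    "    order = np.argsort(-probs)\n",
    "    cumulative = np.cumsum(probs[order])\n",
    "    # keep the smallest prefix whose total mass reaches top_p (always at least one token)\n",
    "    cutoff = np.searchsorted(cumulative, top_p) + 1\n",
    "    keep = order[:cutoff]\n",
    "    kept = probs[keep] / probs[keep].sum()\n",
    "    return int(rng.choice(keep, p=kept))\n",
    "\n",
    "logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])\n",
    "rng = np.random.default_rng(0)\n",
    "samples = [sample_top_p(logits, top_p=0.5, rng=rng) for _ in range(100)]\n",
    "```\n",
    "\n",
    "With `top_p=0.5` only token 0 survives, because it alone carries about 61% of the probability mass; raising `top_p` or `temp` widens the set of candidate tokens."
   ]
  },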
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# allow text wrapping in generated output: https://stackoverflow.com/a/61401455\n",
    "from IPython.display import HTML, display\n",
    "\n",
    "def set_css():\n",
    "  display(HTML('''\n",
    "  <style>\n",
    "    pre {\n",
    "        white-space: pre-wrap;\n",
    "    }\n",
    "  </style>\n",
    "  '''))\n",
    "get_ipython().events.register('pre_run_cell', set_css)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZVzs2TYlvYeX",
    "outputId": "80c8b1e5-0d1c-4799-d682-4a5be0c038a1"
   },
   "source": [
    "def infer(context, top_p=0.9, temp=1.0, gen_len=512):\n",
    "    tokens = tokenizer.encode(context)\n",
    "\n",
    "    provided_ctx = len(tokens)\n",
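    "    pad_amount = seq - provided_ctx\n",
    "    if pad_amount < 0:\n",
    "        raise ValueError(f\"prompt is {provided_ctx} tokens long, but at most {seq} tokens fit in the context\")\n",
    "\n",
    "    # Left-pad the prompt with zeros to exactly `seq` tokens, then repeat it once per sequence in the batch.\n",
    "    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)\n",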
    "    batched_tokens = np.array([padded_tokens] * total_batch)\n",
    "    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)\n",
    "\n",
    "    start = time.time()\n",
    "    output = network.generate(batched_tokens, length, gen_len, {\"top_p\": np.ones(total_batch) * top_p, \"temp\": np.ones(total_batch) * temp})\n",
    "\n",
    "    samples = []\n",
    "    decoded_tokens = output[1][0]\n",
    "\n",
    "    # Prepend the prompt, in bold via ANSI escape codes, to each decoded completion.\n",
    "    for o in decoded_tokens[:, :, 0]:\n",
    "      samples.append(f\"\\033[1m{context}\\033[0m{tokenizer.decode(o)}\")\n",
    "\n",
    "    print(f\"completion done in {time.time() - start:06}s\")\n",
    "    return samples\n",
    "\n",
    "print(infer(\"EleutherAI is\")[0])"
   ],
   "execution_count": 8,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "completion done in 59.615925312042236s\n",
      " a team of senior and junior developers that are working together to push back the frontiers of racing videogames. The team is involved in developing a racing videogame from the ground up. It will feature cars and tracks drawn from a creative process.\n",
      "\n",
      "ELEUTERAI TEAM\n",
      "\n",
      "HISTORY\n",
      "\n",
      "Founded by Gianfrancesco Cengia, EleuterAI is an independent videogame development team that started making videogames in 2004.\n",
      "\n",
      "A young team formed in collaboration between the principle members Gianfrancesco Cengia and Luigi Zampar (as well as Luca Barbero, Rolando Labernini and Simone Trovato). Over a period of five years, the members learned how to write code, how to build efficient software processes and how to design a fun racing game. All the while, they shared an enthusiasm for racing games and the ever-evolving technology that became a shared passion.\n",
      "\n",
      "After five years of collaboration, the team reached the starting point of its car games. Due to the great success of Formula One 2017, the videogame race team found the confidence to start a new phase of their adventure.\n",
      "\n",
      "The innovative approach of the team also brought results to the printed world. Indeed, the videogames that the EleuterAI team published in 2014, 2015 and 2016 are among the highest-selling videogames on the French market.\n",
      "\n",
      "Among the videogame developments produced by the team, Trovato mentions Firepower, Vinicius and Stop and Stares.\n",
      "\n",
      "The EleuterAI experience also showed the team the opportunity to expand its horizons outside of car racing. Indeed, as of today, the group is involved in developing a videogame based on the legendary Italian car brand Alfa Romeo. The development phase is already underway and will be a challenge and new experience for the team. The project is planned to be released in late 2019.\n",
      "\n",
      "CHALLENGE\n",
      "\n",
      "This project will be for the EleuterAI team an extraordinary opportunity to show their capabilities and the expertise acquired in the car race gaming experience. As a result, we could have a videogame of the highest quality.\n",
      "\n",
      "Team-work and collaboration will be key elements to develop a fun racing experience and will be at the center of the game.\n",
      "\n",
      "Vision & Values\n",
      "\n",
      "We are close to our mission, which is to make the best racing videogames. We don’t have any brand vision or values, but we do have a very clear idea of\n"
     ],
     "name": "stdout"
    }
   ]
  },
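  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`infer` left-pads the prompt with zeros to exactly `seq` tokens and passes the real prompt length separately, so every prompt in the batch ends at the same position regardless of its length. The sketch below reproduces just that padding and batching step with arbitrary token IDs (no tokenizer or TPU needed). `left_pad_batch` is a hypothetical helper, and its explicit length check is an addition for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def left_pad_batch(tokens, seq, batch):\n",
    "    # tokens: list of token IDs for one prompt\n",
    "    pad_amount = seq - len(tokens)\n",
    "    if pad_amount < 0:\n",
    "        raise ValueError(f'prompt has {len(tokens)} tokens, more than seq={seq}')\n",
    "    # zeros go on the left, so the prompt's last token sits at index seq - 1\n",
    "    padded = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)\n",
    "    batched = np.array([padded] * batch)\n",
    "    length = np.ones(batch, dtype=np.uint32) * len(tokens)\n",
    "    return batched, length\n",
    "\n",
    "batched, length = left_pad_batch([464, 2068, 318], seq=8, batch=2)\n",
    "print(batched)\n",
    "print(length)\n",
    "```\n",
    "\n",
    "`length` carries the real prompt length alongside the padded array; how `network.generate` uses it internally is not shown in this notebook."
   ]
  },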
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nvlAK6RbCJYg",
    "outputId": "c0525611-ecc8-422e-d8e3-24b3a2159083"
   },
   "source": [
    "#@title  { form-width: \"300px\" }\n",
    "top_p = 0.9 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
    "temp = 1 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
    "\n",
    "context = \"\"\"In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.\"\"\"\n",
    "\n",
    "print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])"
   ],
   "execution_count": 9,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "completion done in 13.510299682617188s\n",
      " The scientists decided to ask one of the creatures to show them what English sounds like.\n",
      "\n",
      "The creature was then transported in a private jet to a real-life, windy city, where it was finally placed on the roof of a building, wearing a set of headphones. The scientist proceeded to ask the creature to demonstrate the letter “R”, which turned out to be an “R”-shaped unicorn.\n",
      "\n",
      "At this point, the scientists expected the creature to start singing, or making a horselike sound. Instead, the unicorn imitated a car horn, roaring “Raaaaaarrrr!”. The scientist chuckled at the unplanned execution. The unicorn finished the demonstration by shouting “Ahhhhhhhhhhhhhhhhh!”, sounding like an escaped convict having an orgasm.\n",
      "\n",
      "Then the creature, after spending about an hour on the rooftop, as it was exhausted, went to sleep. By chance, the creature’s body was mown down by a steamroller a few hours later, and its body was sent to the scientists.\n",
      "\n",
      "The scientists analysed the DNA and created a synthetic version of the creature’s horn, and inserted it into a New Zealand sheep. The sheep were taught to use a computer program for spelling, but after the sheep had been taught to pronounce the letter “R”, it would never stop saying it.\n",
      "\n",
      "Dr. Graeme Lloyd, of the University of Canterbury, New Zealand, said “We created this synthetic horn by the end of the project, and inserted it into the New Zealand sheep. We told the sheep that they had to use the synthetic horn to say “R”, and that they must never say anything else. For a month, the sheep would wake up at dawn, go on to the roof of our office building, and then try to say “R”. In a few weeks, the sheep’s heads would fall off.”\n",
      "\n",
      "Related\n",
      "\n",
      "Comments\n",
      "\n",
      "Before going off to sally forth about the wonders and dangers of the internet, did you perhaps check on BING to see if your best friend had posted up something about you?\n",
      "Or, perhaps it’s more likely that you’re living in a fantasy land, which, when looking at that headline, is probably the case.\n",
      "\n",
      "Actually, that article about unicorns is so unbelievably stupid and self-satifyingly absurd, even I had to\n"
     ],
     "name": "stdout"
    }
   ]
  }
 ]
}